{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ndo4ERqnwQOU"
      },
      "source": [
        "##### Copyright 2020 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "MTKwbguKwT4R"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xfNT-mlFwxVM"
      },
      "source": [
        "# 컨볼루셔널 변이형 오토인코더"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0TD5ZrvEMbhZ"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td><a target=\"_blank\" href=\"https://www.tensorflow.org/tutorials/generative/cvae\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\"> TensorFlow.org에서보기</a></td>\n",
        "  <td><a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/tutorials/generative/cvae.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\"> Google Colab에서 실행</a></td>\n",
        "  <td><a target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/tutorials/generative/cvae.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\"> GitHub에서 소스보기</a></td>\n",
        "  <td><a href=\"https://storage.googleapis.com/tensorflow_docs/docs/site/en/tutorials/generative/cvae.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\"> 노트북 다운로드</a></td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ITZuApL56Mny"
      },
      "source": [
        "이 노트북은 MNIST 데이터세트에서 변이형 오토인코더(VAE, Variational Autoencoder)를 훈련하는 방법을 보여줍니다([1](https://arxiv.org/abs/1312.6114) , [2](https://arxiv.org/abs/1401.4082)). VAE는 오토인코더의 확률론적 형태로, 높은 차원의 입력 데이터를 더 작은 표현으로 압축하는 모델입니다. 입력을 잠재 벡터에 매핑하는 기존의 오토인코더와 달리 VAE는 입력 데이터를 가우스 평균 및 분산과 같은 확률 분포의 매개변수에 매핑합니다. 이 방식은 연속적이고 구조화된 잠재 공간을 생성하므로 이미지 생성에 유용합니다.\n",
        "\n",
        "![CVAE 이미지 잠재 공간](images/cvae_latent_space.jpg)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "e1_Y75QXJS6h"
      },
      "source": [
        "## 설정"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "P-JuIu2N_SQf"
      },
      "outputs": [],
      "source": [
        "!pip install tensorflow-probability\n",
        "\n",
        "# to generate gifs\n",
        "!pip install imageio\n",
        "!pip install git+https://github.com/tensorflow/docs"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YfIk2es3hJEd"
      },
      "outputs": [],
      "source": [
        "from IPython import display\n",
        "\n",
        "import glob\n",
        "import imageio\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "import PIL\n",
        "import tensorflow as tf\n",
        "import tensorflow_probability as tfp\n",
        "import time"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iYn4MdZnKCey"
      },
      "source": [
        "## MNIST 데이터세트 로드하기\n",
        "\n",
        "각 MNIST 이미지는 원래 각각 0-255 사이인 784개의 정수로 구성된 벡터이며 픽셀 강도를 나타냅니다. 모델에서 Bernoulli 분포를 사용하여 각 픽셀을 모델링하고 데이터세트를 정적으로 이진화합니다."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "a4fYMGxGhrna"
      },
      "outputs": [],
      "source": [
        "(train_images, _), (test_images, _) = tf.keras.datasets.mnist.load_data()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NFC2ghIdiZYE"
      },
      "outputs": [],
      "source": [
        "def preprocess_images(images):\n",
        "  images = images.reshape((images.shape[0], 28, 28, 1)) / 255.\n",
        "  return np.where(images > .5, 1.0, 0.0).astype('float32')\n",
        "\n",
        "train_images = preprocess_images(train_images)\n",
        "test_images = preprocess_images(test_images)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "S4PIDhoDLbsZ"
      },
      "outputs": [],
      "source": [
        "train_size = 60000\n",
        "batch_size = 32\n",
        "test_size = 10000"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PIGN6ouoQxt3"
      },
      "source": [
        "## *tf.data*를 사용하여 데이터 배치 및 셔플 처리하기"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-yKCCQOoJ7cn"
      },
      "outputs": [],
      "source": [
        "train_dataset = (tf.data.Dataset.from_tensor_slices(train_images)\n",
        "                 .shuffle(train_size).batch(batch_size))\n",
        "test_dataset = (tf.data.Dataset.from_tensor_slices(test_images)\n",
        "                .shuffle(test_size).batch(batch_size))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "THY-sZMiQ4UV"
      },
      "source": [
        "## *tf.keras.Sequential*을 사용하여 인코더 및 디코더 네트워크 정의하기\n",
        "\n",
        "VAE 예제에서는 인코더 및 디코더 네트워크에 두 개의 작은 ConvNet을 사용합니다. 문헌에서 이들 네트워크는 각각 추론/인식 및 생성 모델로도 지칭됩니다. 구현을 단순화하기 위해 `tf.keras.Sequential`을 사용합니다. 다음 설명에서 $x$와 $z$는 각각 관측 값과 잠재 변수를 나타냅니다.\n",
        "\n",
        "### 인코더 네트워크\n",
        "\n",
        "이것은 근사 사후 분포 $q(z|x)$를 정의합니다. 이 분포는 관측 값을 입력으로 받고 잠재 표현 $z$의 조건부 분포를 지정하기 위한 매개변수 세트를 출력합니다. 이 예에서는 분포를 대각선 가우스로 간단히 모델링하고, 네트워크는 인수 분해된 가우스의 평균 및 로그-분산 매개변수를 출력합니다. 수치 안정성을 위해 분산을 직접 출력하지 않고 로그-분산을 출력합니다.\n",
        "\n",
        "### 디코더 네트워크\n",
        "\n",
        "잠복 샘플 $z$를 입력으로 사용하여 관측 값의 조건부 분포에 대한 매개변수를 출력하는 관측 값 $p(x|z)$의 조건부 분포를 정의합니다. 잠재 이전 분포 $p(z)$를 단위 가우스로 모델링합니다.\n",
        "\n",
        "### 재매개변수화 트릭\n",
        "\n",
        "훈련 중에 디코더에 대해 샘플 $z$를 생성하기 위해, 입력 관측 값 $x$가 주어졌을 때 인코더에 의해 출력된 매개변수로 정의된 잠재 분포로부터 샘플링할 수 있습니다. 그러나 역전파가 무작위 노드를 통해 흐를 수 없기 때문에 이 샘플링 작업에서 병목 현상이 발생합니다.\n",
        "\n",
        "이를 해결하기 위해 재매개변수화 트릭을 사용합니다. 이 예에서는 디코더 매개변수와 다른 매개변수 $\\epsilon$을 다음과 같이 사용하여 $z$를 근사시킵니다.\n",
        "\n",
        "$$z = \\mu + \\sigma \\odot \\epsilon$$\n",
        "\n",
        "여기서 $\\mu$ 및 $\\sigma$는 각각 가우스 분포의 평균 및 표준 편차를 나타냅니다. 이들은 디코더 출력에서 파생될 수 있습니다. $\\epsilon$은 $z$의 무질서도를 유지하는 데 사용되는 무작위 노이즈로 생각할 수 있습니다. 표준 정규 분포에서 $\\epsilon$을 생성합니다.\n",
        "\n",
        "잠재 변수 $z$는 이제 $\\mu$, $\\sigma$ 및 $\\epsilon$의 함수에 의해 생성되며, 이를 통해 모델은 각각 $\\mu$ 및 $\\sigma$를 통해 인코더의 그래디언트를 역전파하면서 $epsepson$를 통해 무질서도를 유지할 수 있습니다.\n",
        "\n",
        "### 네트워크 아키텍처\n",
        "\n",
        "인코더 네트워크의 경우 두 개의 컨볼루셔널 레이어, 그리고 이어서 완전히 연결된 레이어를 사용합니다. 디코더 네트워크에서 완전히 연결된 레이어와 그 뒤에 세 개의 컨볼루션 전치 레이어(일부 컨텍스트에서는 디컨볼루셔널 레이어라고도 함)를 사용하여 이 아키텍처를 미러링합니다. 미니 배치 사용으로 인한 추가 무질서도가 샘플링의 무질서에 더해져 불안정성을 높일 수 있으므로 VAE 훈련 시 배치 정규화를 사용하지 않는 것이 일반적입니다.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "VGLbvBEmjK0a"
      },
      "outputs": [],
      "source": [
        "class CVAE(tf.keras.Model):\n",
        "  \"\"\"Convolutional variational autoencoder.\"\"\"\n",
        "\n",
        "  def __init__(self, latent_dim):\n",
        "    super(CVAE, self).__init__()\n",
        "    self.latent_dim = latent_dim\n",
        "    self.encoder = tf.keras.Sequential(\n",
        "        [\n",
        "            tf.keras.layers.InputLayer(input_shape=(28, 28, 1)),\n",
        "            tf.keras.layers.Conv2D(\n",
        "                filters=32, kernel_size=3, strides=(2, 2), activation='relu'),\n",
        "            tf.keras.layers.Conv2D(\n",
        "                filters=64, kernel_size=3, strides=(2, 2), activation='relu'),\n",
        "            tf.keras.layers.Flatten(),\n",
        "            # No activation\n",
        "            tf.keras.layers.Dense(latent_dim + latent_dim),\n",
        "        ]\n",
        "    )\n",
        "\n",
        "    self.decoder = tf.keras.Sequential(\n",
        "        [\n",
        "            tf.keras.layers.InputLayer(input_shape=(latent_dim,)),\n",
        "            tf.keras.layers.Dense(units=7*7*32, activation=tf.nn.relu),\n",
        "            tf.keras.layers.Reshape(target_shape=(7, 7, 32)),\n",
        "            tf.keras.layers.Conv2DTranspose(\n",
        "                filters=64, kernel_size=3, strides=2, padding='same',\n",
        "                activation='relu'),\n",
        "            tf.keras.layers.Conv2DTranspose(\n",
        "                filters=32, kernel_size=3, strides=2, padding='same',\n",
        "                activation='relu'),\n",
        "            # No activation\n",
        "            tf.keras.layers.Conv2DTranspose(\n",
        "                filters=1, kernel_size=3, strides=1, padding='same'),\n",
        "        ]\n",
        "    )\n",
        "\n",
        "  @tf.function\n",
        "  def sample(self, eps=None):\n",
        "    if eps is None:\n",
        "      eps = tf.random.normal(shape=(100, self.latent_dim))\n",
        "    return self.decode(eps, apply_sigmoid=True)\n",
        "\n",
        "  def encode(self, x):\n",
        "    mean, logvar = tf.split(self.encoder(x), num_or_size_splits=2, axis=1)\n",
        "    return mean, logvar\n",
        "\n",
        "  def reparameterize(self, mean, logvar):\n",
        "    eps = tf.random.normal(shape=mean.shape)\n",
        "    return eps * tf.exp(logvar * .5) + mean\n",
        "\n",
        "  def decode(self, z, apply_sigmoid=False):\n",
        "    logits = self.decoder(z)\n",
        "    if apply_sigmoid:\n",
        "      probs = tf.sigmoid(logits)\n",
        "      return probs\n",
        "    return logits"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0FMYgY_mPfTi"
      },
      "source": [
        "## 손실 함수 및 옵티마이저 정의하기\n",
        "\n",
        "VAE는 한계 로그-우도에 대한 ELBO(evidence lower bound)를 최대화하여 훈련합니다.\n",
        "\n",
        "$$\\log p(x) \\ge \\text{ELBO} = \\mathbb{E}_{q(z|x)}\\left[\\log \\frac{p(x, z)}{q(z|x)}\\right].$$\n",
        "\n",
        "실제로, 이 예상에 대한 단일 샘플 Monte Carlo 추정값을 최적화합니다.\n",
        "\n",
        "$$\\log p(x| z) + \\log p(z) - \\log q(z|x),$$ 여기서 $z$는 $q(z|x)$에서 샘플링됩니다.\n",
        "\n",
        "**참고**: KL 항을 분석적으로 계산할 수도 있지만 여기서는 단순화를 위해 Monte Carlo 예측 도구에 세 항을 모두 통합합니다."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "iWCn_PVdEJZ7"
      },
      "outputs": [],
      "source": [
        "optimizer = tf.keras.optimizers.Adam(1e-4)\n",
        "\n",
        "\n",
        "def log_normal_pdf(sample, mean, logvar, raxis=1):\n",
        "  log2pi = tf.math.log(2. * np.pi)\n",
        "  return tf.reduce_sum(\n",
        "      -.5 * ((sample - mean) ** 2. * tf.exp(-logvar) + logvar + log2pi),\n",
        "      axis=raxis)\n",
        "\n",
        "\n",
        "def compute_loss(model, x):\n",
        "  mean, logvar = model.encode(x)\n",
        "  z = model.reparameterize(mean, logvar)\n",
        "  x_logit = model.decode(z)\n",
        "  cross_ent = tf.nn.sigmoid_cross_entropy_with_logits(logits=x_logit, labels=x)\n",
        "  logpx_z = -tf.reduce_sum(cross_ent, axis=[1, 2, 3])\n",
        "  logpz = log_normal_pdf(z, 0., 0.)\n",
        "  logqz_x = log_normal_pdf(z, mean, logvar)\n",
        "  return -tf.reduce_mean(logpx_z + logpz - logqz_x)\n",
        "\n",
        "\n",
        "@tf.function\n",
        "def train_step(model, x, optimizer):\n",
        "  \"\"\"Executes one training step and returns the loss.\n",
        "\n",
        "  This function computes the loss and gradients, and uses the latter to\n",
        "  update the model's parameters.\n",
        "  \"\"\"\n",
        "  with tf.GradientTape() as tape:\n",
        "    loss = compute_loss(model, x)\n",
        "  gradients = tape.gradient(loss, model.trainable_variables)\n",
        "  optimizer.apply_gradients(zip(gradients, model.trainable_variables))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Rw1fkAczTQYh"
      },
      "source": [
        "## 훈련하기\n",
        "\n",
        "- 데이터세트에 대한 반복 실행을 시작합니다.\n",
        "- 반복하는 동안 매번 이미지를 인코더로 전달하여 근사적인 사후 $q(z|x)$의 평균 및 로그-분산 매개변수 세트를 얻습니다.\n",
        "- 그런 다음 $q(z|x)$에서 샘플링하기 위해 *재매개변수화 트릭*을 적용합니다.\n",
        "- 마지막으로, 생성된 분포 $p(x|z)$의 로짓을 얻기 위해 재매개변수화된 샘플을 디코더로 전달합니다.\n",
        "- **참고:** 훈련 세트에 60k 데이터 포인트와 테스트 세트에 10k 데이터 포인트가 있는 keras에 의해 로드된 데이터세트를 사용하기 때문에 테스트세트에 대한 결과 ELBO는 Larochelle MNIST의 동적 이진화를 사용하는 문헌에서 보고된 결과보다 약간 높습니다.\n",
        "\n",
        "### 이미지 생성하기\n",
        "\n",
        "- 훈련을 마쳤으면 이미지를 생성할 차례입니다.\n",
        "- 우선, 단위 가우스 사전 분포 $p(z)$에서 잠재 벡터 세트를 샘플링합니다.\n",
        "- 그러면 생성기가 잠재 샘플 $z$를 관측 값의 로짓으로 변환하여 분포 $p(x|z)$를 제공합니다.\n",
        "- 여기서 Bernoulli 분포의 확률을 플롯합니다.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NS2GWywBbAWo"
      },
      "outputs": [],
      "source": [
        "epochs = 10\n",
        "# set the dimensionality of the latent space to a plane for visualization later\n",
        "latent_dim = 2\n",
        "num_examples_to_generate = 16\n",
        "\n",
        "# keeping the random vector constant for generation (prediction) so\n",
        "# it will be easier to see the improvement.\n",
        "random_vector_for_generation = tf.random.normal(\n",
        "    shape=[num_examples_to_generate, latent_dim])\n",
        "model = CVAE(latent_dim)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RmdVsmvhPxyy"
      },
      "outputs": [],
      "source": [
        "def generate_and_save_images(model, epoch, test_sample):\n",
        "  mean, logvar = model.encode(test_sample)\n",
        "  z = model.reparameterize(mean, logvar)\n",
        "  predictions = model.sample(z)\n",
        "  fig = plt.figure(figsize=(4, 4))\n",
        "\n",
        "  for i in range(predictions.shape[0]):\n",
        "    plt.subplot(4, 4, i + 1)\n",
        "    plt.imshow(predictions[i, :, :, 0], cmap='gray')\n",
        "    plt.axis('off')\n",
        "\n",
        "  # tight_layout minimizes the overlap between 2 sub-plots\n",
        "  plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))\n",
        "  plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "swCyrbqQQ-Ri"
      },
      "outputs": [],
      "source": [
        "# Pick a sample of the test set for generating output images\n",
        "assert batch_size >= num_examples_to_generate\n",
        "for test_batch in test_dataset.take(1):\n",
        "  test_sample = test_batch[0:num_examples_to_generate, :, :, :]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2M7LmLtGEMQJ"
      },
      "outputs": [],
      "source": [
        "generate_and_save_images(model, 0, test_sample)\n",
        "\n",
        "for epoch in range(1, epochs + 1):\n",
        "  start_time = time.time()\n",
        "  for train_x in train_dataset:\n",
        "    train_step(model, train_x, optimizer)\n",
        "  end_time = time.time()\n",
        "\n",
        "  loss = tf.keras.metrics.Mean()\n",
        "  for test_x in test_dataset:\n",
        "    loss(compute_loss(model, test_x))\n",
        "  elbo = -loss.result()\n",
        "  display.clear_output(wait=False)\n",
        "  print('Epoch: {}, Test set ELBO: {}, time elapse for current epoch: {}'\n",
        "        .format(epoch, elbo, end_time - start_time))\n",
        "  generate_and_save_images(model, epoch, test_sample)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "P4M_vIbUi7c0"
      },
      "source": [
        "### 마지막 훈련 epoch에서 생성된 이미지 표시하기"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WfO5wCdclHGL"
      },
      "outputs": [],
      "source": [
        "def display_image(epoch_no):\n",
        "  return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5x3q9_Oe5q0A"
      },
      "outputs": [],
      "source": [
        "plt.imshow(display_image(epoch))\n",
        "plt.axis('off')  # Display images"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NywiH3nL8guF"
      },
      "source": [
        "### 저장된 모든 이미지의 애니메이션 GIF 표시하기"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IGKQgENQ8lEI"
      },
      "outputs": [],
      "source": [
        "anim_file = 'cvae.gif'\n",
        "\n",
        "with imageio.get_writer(anim_file, mode='I') as writer:\n",
        "  filenames = glob.glob('image*.png')\n",
        "  filenames = sorted(filenames)\n",
        "  for filename in filenames:\n",
        "    image = imageio.imread(filename)\n",
        "    writer.append_data(image)\n",
        "  image = imageio.imread(filename)\n",
        "  writer.append_data(image)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2ZqAEtdqUmJF"
      },
      "outputs": [],
      "source": [
        "import tensorflow_docs.vis.embed as embed\n",
        "embed.embed_file(anim_file)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PeunRU6TSumT"
      },
      "source": [
        "### 잠재 공간에서 숫자의 2D 형태 표시하기\n",
        "\n",
        "아래 코드를 실행하면 여러 숫자 클래스의 연속 분포가 표시되며 각 숫자는 2D 잠재 공간에서 다른 숫자로 변형됩니다. 잠재 공간에 대한 표준 정규 분포를 생성하기 위해 [TensorFlow Probability](https://www.tensorflow.org/probability)를 사용합니다."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "code",
        "id": "mNcaaYPBS3mj"
      },
      "outputs": [],
      "source": [
        "def plot_latent_images(model, n, digit_size=28):\n",
        "  \"\"\"Plots n x n digit images decoded from the latent space.\"\"\"\n",
        "\n",
        "  norm = tfp.distributions.Normal(0, 1)\n",
        "  grid_x = norm.quantile(np.linspace(0.05, 0.95, n))\n",
        "  grid_y = norm.quantile(np.linspace(0.05, 0.95, n))\n",
        "  image_width = digit_size*n\n",
        "  image_height = image_width\n",
        "  image = np.zeros((image_height, image_width))\n",
        "\n",
        "  for i, yi in enumerate(grid_x):\n",
        "    for j, xi in enumerate(grid_y):\n",
        "      z = np.array([[xi, yi]])\n",
        "      x_decoded = model.sample(z)\n",
        "      digit = tf.reshape(x_decoded[0], (digit_size, digit_size))\n",
        "      image[i * digit_size: (i + 1) * digit_size,\n",
        "            j * digit_size: (j + 1) * digit_size] = digit.numpy()\n",
        "\n",
        "  plt.figure(figsize=(10, 10))\n",
        "  plt.imshow(image, cmap='Greys_r')\n",
        "  plt.axis('Off')\n",
        "  plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "F-ZG69QCZnGY"
      },
      "outputs": [],
      "source": [
        "plot_latent_images(model, 20)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HrJRef8Ln945"
      },
      "source": [
        "## 다음 단계\n",
        "\n",
        "이 튜토리얼에서는 TensorFlow를 사용하여 컨볼루셔널 변이형 오토인코더를 구현하는 방법을 설명했습니다.\n",
        "\n",
        "다음 단계로 네트워크 크기를 증가시켜 모델 출력의 개선을 시도할 수 있습니다. 예를 들어, 각 `Conv2D` 및`Conv2DTranspose` 레이어에 대한 `filter` 매개변수를 512로 설정할 수 있습니다. 최종 2D 잠재 이미지 플롯을 생성하기 위해 `latent_dim`을 2로 유지해야 합니다. 또한, 네트워크 크기가 증가할수록 훈련 시간이 늘어납니다.\n",
        "\n",
        "CIFAR-10과 같은 다른 데이터세트를 사용하여 VAE를 구현할 수도 있습니다.\n",
        "\n",
        "VAE는 여러 가지 스타일과 다양한 복잡도로 구현할 수 있습니다. 다음 자료에서 추가 구현을 찾을 수 있습니다.\n",
        "\n",
        "- [변이형 오토인코더(keras.io)](https://keras.io/examples/generative/vae/)\n",
        "- [\"사용자 정의 레이어 및 모델 작성\" 가이드의 VAE 예제(tensorflow.org)](https://www.tensorflow.org/guide/keras/custom_layers_and_models#putting_it_all_together_an_end-to-end_example)\n",
        "- [TFP 확률 레이어: 변이형 오토인코더](https://www.tensorflow.org/probability/examples/Probabilistic_Layers_VAE)\n",
        "\n",
        "VAE에 대한 자세한 내용을 알아보려면 [변이형 오토인코더 소개](https://arxiv.org/abs/1906.02691)를 참조하세요."
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "name": "cvae.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
